package rdb

import (
	"context"
	"fmt"
	"strconv"
	"sync"
	"time"

	"github.com/redis/go-redis/v9"
)

type MapLimit interface {
	TopicKey() string
	RedisClient() *redis.Client
	Add(key string, delta uint) error
	Sub(key string, delta uint) error
	Get(key string) (int64, error)
	Del(key string) error
	Reset(key string) error
	IsToLimit(key string, max int64) bool
	Count() (int64, error)
	GetAll() (map[string]int64, error)
	Scan(dest interface{}) error
	ClearAll() error
}

type _mapLimit struct {
	Topic string        //主题
	Redis *redis.Client //Redis客户端
	RWL   sync.RWMutex  //读写锁
}

func NewLimit(client *redis.Client, topic string) MapLimit {
	return &_mapLimit{
		Topic: topic,
		Redis: client,
		RWL:   sync.RWMutex{},
	}
}

// TopicKey 返回默认生成的主题key
func (l *_mapLimit) TopicKey() string {
	return fmt.Sprint("MapLimit_", l.Topic)
}

// RedisClient 返回绑定的Redis客户端
func (l *_mapLimit) RedisClient() *redis.Client {
	return l.Redis
}

func (l *_mapLimit) Add(key string, delta uint) error {
	l.RWL.Lock()
	err := l.Redis.HIncrBy(context.Background(), l.TopicKey(), key, int64(delta)).Err()
	l.RWL.Unlock()
	return err
}

func (l *_mapLimit) Sub(key string, delta uint) error {
	l.RWL.Lock()
	err := l.Redis.HIncrBy(context.Background(), l.TopicKey(), key, int64(-delta)).Err()
	l.RWL.Unlock()
	return err
}

func (l *_mapLimit) Get(key string) (int64, error) {
	l.RWL.RLock()
	defer l.RWL.RUnlock()
	result, err := l.Redis.HGet(context.Background(), l.TopicKey(), key).Int64()
	if err == nil || err == redis.Nil {
		return result, nil
	}
	return result, err
}

func (l *_mapLimit) Del(key string) error {
	l.RWL.Lock()
	err := l.Redis.HDel(context.Background(), l.TopicKey(), key).Err()
	l.RWL.Unlock()
	return err
}

func (l *_mapLimit) Reset(key string) error {
	l.RWL.Lock()
	err := l.Redis.HSet(context.Background(), l.TopicKey(), key, 0).Err()
	l.RWL.Unlock()
	return err
}

func (l *_mapLimit) IsToLimit(key string, max int64) bool {
	l.RWL.RLock()
	var result int64 = 0
	for i := 0; i < 10; i++ {
		temp, err := l.Redis.HGet(context.Background(), l.TopicKey(), key).Int64()
		if err == nil {
			result = temp
			break
		}
		if err == redis.Nil {
			result = 0
			break
		}
		time.Sleep(time.Millisecond * 100)
	}
	l.RWL.RUnlock()
	return result >= max
}

// Count 获取哈希表中字段的数量
func (l *_mapLimit) Count() (int64, error) {
	l.RWL.RLock()
	defer l.RWL.RUnlock()
	result, err := l.Redis.HLen(context.Background(), l.TopicKey()).Result()
	if err == nil || err == redis.Nil {
		return result, nil
	}
	return result, err
}

// GetAll 获取所有的key
func (l *_mapLimit) GetAll() (map[string]int64, error) {
	l.RWL.RLock()
	result, err := l.Redis.HGetAll(context.Background(), l.TopicKey()).Result()
	l.RWL.RUnlock()
	var temp = make(map[string]int64)
	for k, v := range result {
		parseInt, _ := strconv.ParseInt(fmt.Sprint(v), 10, 64)
		temp[k] = parseInt
	}
	return temp, err
}

// Scan 将结果扫描到目标结构体中
func (l *_mapLimit) Scan(dest interface{}) error {
	l.RWL.RLock()
	err := l.Redis.HGetAll(context.Background(), l.TopicKey()).Scan(dest)
	l.RWL.RUnlock()
	return err
}

// ClearAll 清空所有的记录（直接删除key）
func (l *_mapLimit) ClearAll() error {
	l.RWL.Lock()
	err := l.Redis.Del(context.Background(), l.TopicKey()).Err()
	l.RWL.Unlock()
	return err
}
